load "$NCARG_ROOT/lib/ncarg/nclex/gsun/gsn_code.ncl"
load "$NCARG_ROOT/lib/ncarg/nclscripts/csm/gsn_csm.ncl"
load "$NCARG_ROOT/lib/ncarg/nclscripts/csm/contributed.ncl"
load "$NCARG_ROOT/lib/ncarg/nclscripts/csm/shea_util.ncl"

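; Plumb fluxes: horizontal components Fx and Fy.
; Follows the derivation at https://paperpile.com/view/3fd7d9e7-351f-08da-ba1d-af2f7191ed86
; Fluxes are calculated from climatological seasonal-mean streamfunction,
; with perturbations defined as deviations from the zonal mean
; (printPlumb_time instead computes fluxes per time step, then time-averages them).
; Quasi-geostrophic, in spherical coordinates.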



undef("printPlumb2D")
;------------------------------------------------------------------------
; printPlumb2D: horizontal Plumb flux components on a single pressure level.
;
; Removes the zonal mean from the streamfunction SF and builds Fx and Fy
; from the deviation field and its longitude/latitude derivatives. All
; results are written to the open output file filo.
;
; Arguments:
;   file_in : file handle. Its "lat" and "lon" variables (degrees) are
;             used as the horizontal coordinates.
;   SF      : streamfunction with named dimensions "lat" and "lon".
;             The output metadata below assumes the order (lat, lon).
;   lev_p   : pressure level of SF, used as lev_p/1000 in the coefficient.
;             Presumably in mb -- confirm against callers.
;   filo    : writable output file handle.
;
; Writes to filo: Fx, Fy, PSIdev, SF, dPSIdevdlon, dPSIdevdlat,
;                 ddPSIdevdlonlon, ddPSIdevdlonlat.
; Returns 1.
;
; NOTE(review): copy_VarMeta(Fx,SF) also rewrites the metadata of the
; caller's SF, because NCL passes arguments by reference.
;------------------------------------------------------------------------
function printPlumb2D(file_in,SF,lev_p,filo)
local a, pi, lat, lon, dimNames, lat_idx, lon_idx, PSIdev, latr, lonr, \
      coslat, dPSIdevdlon, ddPSIdevdlonlon, dPSIdevdlat, ddPSIdevdlonlat, \
      xterm1, xterm2, yterm1, yterm2, coeff, Fx, Fy
begin

    a  = 6.37122e06  ; radius of Earth (m)
    pi = 3.14159265358979

    lat = tofloat(file_in->lat)
    lat!0 = "lat"
    lon = tofloat(file_in->lon)
    lon!0 = "lon"

    ; Locate the lat/lon dimensions of SF by name (ind ignores unnamed dims).
    dimNames = getvardims(SF)
    lat_idx  = ind(dimNames .eq. "lat")
    lon_idx  = ind(dimNames .eq. "lon")
    if (ismissing(lat_idx) .or. ismissing(lon_idx)) then
        print("printPlumb2D: SF must have named dimensions 'lat' and 'lon'")
        exit
    end if

    ; Deviation streamfunction: SF with its zonal mean removed.
    PSIdev = dim_rmvmean_n(SF,lon_idx)

    ; Lat/lon in radians, and cos(lat) broadcast to the shape of SF.
    latr   = pi/180.0 * lat
    lonr   = pi/180.0 * lon
    coslat = conform_dims(dimsizes(SF),cos(latr),lat_idx)

    ; Derivatives of PSIdev. Longitude is treated as cyclic (True);
    ; latitude is not (False).
    dPSIdevdlon     = center_finite_diff_n(PSIdev,lonr,True,0,lon_idx)
    ddPSIdevdlonlon = center_finite_diff_n(dPSIdevdlon,lonr,True,0,lon_idx)
    dPSIdevdlat     = center_finite_diff_n(PSIdev,latr,False,0,lat_idx)
    ddPSIdevdlonlat = center_finite_diff_n(dPSIdevdlon,latr,False,0,lat_idx)

    ; Bracketed terms of each component, with cos(lat) mostly taken inside.
    xterm1 = dPSIdevdlon * dPSIdevdlon
    xterm2 = PSIdev * ddPSIdevdlonlon
    yterm1 = dPSIdevdlat * dPSIdevdlon
    yterm2 = PSIdev * ddPSIdevdlonlat

    ; Coefficient (p/1000)/(2 a^2): pressure normalized by 1000 mb.
    coeff = conform_dims(dimsizes(SF),lev_p/(1000.0 *2.0 * a * a),-1)

    ; Flux components. NOTE: no masking of weak or easterly winds is applied.
    Fx = (coeff/coslat) * (xterm1 - xterm2)
    Fy = coeff * (yterm1 - yterm2)

    ; Output metadata (assumes SF is ordered (lat, lon)).
    print(dimsizes(Fx))
    Fx!0 = "lat"
    Fx&lat = lat
    Fx!1 = "lon"
    Fx&lon = lon

    copy_VarMeta(Fx,Fy)
    copy_VarMeta(Fx,PSIdev)
    copy_VarMeta(Fx,SF)

    ; Every derivative written below needs named dims, dPSIdevdlon included.
    copy_VarMeta(Fx,dPSIdevdlon)
    copy_VarMeta(Fx,dPSIdevdlat)
    copy_VarMeta(Fx,ddPSIdevdlonlon)
    copy_VarMeta(Fx,ddPSIdevdlonlat)

    Fx@units = "m^2/s^2"
    Fy@units = "m^2/s^2"
    PSIdev@units = "m^2/s"

    filo->Fx = Fx
    filo->Fy = Fy
    filo->PSIdev = PSIdev
    filo->SF = SF
    filo->dPSIdevdlon = dPSIdevdlon
    filo->dPSIdevdlat = dPSIdevdlat
    filo->ddPSIdevdlonlon = ddPSIdevdlonlon
    filo->ddPSIdevdlonlat = ddPSIdevdlonlat

    return(1)

end


undef("printPlumb")
;------------------------------------------------------------------------
; printPlumb: horizontal Plumb flux components on multiple pressure levels.
;
; Same calculation as the single-level version, except that SF carries a
; pressure dimension and the coefficient varies with level.
;
; Arguments:
;   file_in : file handle supplying "lat", "lon" (degrees) and "lev_p".
;   SF      : streamfunction with named dimensions "lev_p", "lat" and "lon".
;             The output metadata assumes the order (lev_p, lat, lon).
;   filo    : writable output file handle.
;
; Writes to filo, with the pressure dimension named "level":
;   Fx, Fy, PSIdev, SF, dPSIdevdlon, dPSIdevdlat, ddPSIdevdlonlon,
;   ddPSIdevdlonlat.
; Returns 1.
;
; NOTE(review): coeff uses level/1000, so file_in->lev_p is presumably in
; mb -- confirm. copy_VarMeta(Fx,SF) also alters the caller's SF metadata.
;------------------------------------------------------------------------
function printPlumb(file_in,SF,filo)
local a, pi, lat, lon, level, dimNames, lev_idx, lat_idx, lon_idx, \
      PSIdev, latr, lonr, coslat, dPSIdevdlon, ddPSIdevdlonlon, \
      dPSIdevdlat, ddPSIdevdlonlat, xterm1, xterm2, yterm1, yterm2, \
      coeff, Fx, Fy
begin

    a  = 6.37122e06  ; radius of Earth (m)
    pi = 3.14159265358979

    lat = tofloat(file_in->lat)
    lat!0 = "lat"
    lon = tofloat(file_in->lon)
    lon!0 = "lon"
    level = file_in->lev_p

    ; Locate the lev_p/lat/lon dimensions of SF by name.
    dimNames = getvardims(SF)
    lev_idx  = ind(dimNames .eq. "lev_p")
    lat_idx  = ind(dimNames .eq. "lat")
    lon_idx  = ind(dimNames .eq. "lon")
    if (ismissing(lev_idx) .or. ismissing(lat_idx) .or. ismissing(lon_idx)) then
        print("printPlumb: SF must have named dimensions 'lev_p', 'lat' and 'lon'")
        exit
    end if

    ; Deviation streamfunction: SF with its zonal mean removed.
    PSIdev = dim_rmvmean_n(SF,lon_idx)

    ; Lat/lon in radians, and cos(lat) broadcast to the shape of SF.
    latr   = pi/180.0 * lat
    lonr   = pi/180.0 * lon
    coslat = conform_dims(dimsizes(SF),cos(latr),lat_idx)

    ; Derivatives of PSIdev. Longitude is treated as cyclic (True);
    ; latitude is not (False).
    dPSIdevdlon     = center_finite_diff_n(PSIdev,lonr,True,0,lon_idx)
    ddPSIdevdlonlon = center_finite_diff_n(dPSIdevdlon,lonr,True,0,lon_idx)
    dPSIdevdlat     = center_finite_diff_n(PSIdev,latr,False,0,lat_idx)
    ddPSIdevdlonlat = center_finite_diff_n(dPSIdevdlon,latr,False,0,lat_idx)

    ; The vertical component is not computed. Earlier draft, kept for reference
    ; (used Z instead of -scaleheight * log(level/1000)):
    ;   dthdz = center_finite_diff_n(TH,Z,False,0,lev_idx)
    ;   NN = (g/TH) * dthdz
    ;   dPSIdevdz = center_finite_diff_n(PSIdev,Z,False,0,lev_idx)
    ;   ddPSIdevdlonz = center_finite_diff_n(dPSIdevdlon,Z,False,0,lev_idx)
    ;   ddPSIdevdlatz = center_finite_diff_n(dPSIdevdlat,Z,False,0,lev_idx)

    ; Bracketed terms of each component, with cos(lat) mostly taken inside.
    xterm1 = dPSIdevdlon * dPSIdevdlon
    xterm2 = PSIdev * ddPSIdevdlonlon
    yterm1 = dPSIdevdlat * dPSIdevdlon
    yterm2 = PSIdev * ddPSIdevdlonlat

    ; Coefficient (p/1000)/(2 a^2) per level: pressure normalized by 1000 mb.
    coeff = conform_dims(dimsizes(SF),level/(1000.0 *2.0 * a * a),lev_idx)

    ; Flux components. NOTE: no masking of weak or easterly winds is applied.
    Fx = (coeff/coslat) * (xterm1 - xterm2)
    Fy = coeff * (yterm1 - yterm2)

    ; Output metadata (assumes SF is ordered (lev_p, lat, lon)).
    print(dimsizes(Fx))
    Fx!0 = "level"
    Fx&level = level
    Fx!1 = "lat"
    Fx&lat = lat
    Fx!2 = "lon"
    Fx&lon = lon

    copy_VarMeta(Fx,Fy)
    copy_VarMeta(Fx,PSIdev)
    copy_VarMeta(Fx,SF)

    ; Every derivative written below needs named dims, dPSIdevdlon included.
    copy_VarMeta(Fx,dPSIdevdlon)
    copy_VarMeta(Fx,dPSIdevdlat)
    copy_VarMeta(Fx,ddPSIdevdlonlon)
    copy_VarMeta(Fx,ddPSIdevdlonlat)

    Fx@units = "m^2/s^2"
    Fy@units = "m^2/s^2"
    PSIdev@units = "m^2/s"

    filo->Fx = Fx
    filo->Fy = Fy
    filo->PSIdev = PSIdev
    filo->SF = SF
    filo->dPSIdevdlon = dPSIdevdlon
    filo->dPSIdevdlat = dPSIdevdlat
    filo->ddPSIdevdlonlon = ddPSIdevdlonlon
    filo->ddPSIdevdlonlat = ddPSIdevdlonlat

    return(1)

end


undef("printPlumb_time")
;------------------------------------------------------------------------
; printPlumb_time: time-mean horizontal Plumb flux on one pressure level.
;
; Fx and Fy are computed at every time step from the zonal-deviation
; streamfunction and then averaged over the time dimension. SF and PSIdev
; are also time-averaged. The derivative fields are written with their
; time dimension intact.
;
; Arguments:
;   file_in : file handle. Its "lat" and "lon" variables (degrees) are
;             used as the horizontal coordinates.
;   SF_ts   : streamfunction with named dimensions "time", "lat" and "lon".
;             The output metadata assumes lat precedes lon.
;   lev_p   : pressure level in mb (checked below), used as lev_p/1000.
;   filo    : writable output file handle.
;
; Writes to filo: Fx, Fy, PSIdev, SF (time means) and dPSIdevdlon,
;                 dPSIdevdlat, ddPSIdevdlonlon, ddPSIdevdlonlat (per time).
; Returns 1.
;------------------------------------------------------------------------
function printPlumb_time(file_in,SF_ts,lev_p,filo)
local a, pi, lat, lon, dimNames, idim, lat_idx, lon_idx, time_idx, \
      PSIdev_ts, latr, lonr, coslat, dPSIdevdlon, ddPSIdevdlonlon, \
      dPSIdevdlat, ddPSIdevdlonlat, xterm1, xterm2, yterm1, yterm2, \
      coeff, Fx_ts, Fy_ts, Fx, Fy, SF, PSIdev
begin

    ; coeff below uses lev_p/1000, so lev_p must be given in mb.
    ; (NCL's ">" is the max-selection operator, so use .gt. for comparison.)
    if (lev_p .gt. 1000) then
        print("lev_p must be in mb")
        exit
    end if

    a  = 6.37122e06  ; radius of Earth (m)
    pi = 3.14159265358979

    lat = tofloat(file_in->lat)
    lat!0 = "lat"
    lon = tofloat(file_in->lon)
    lon!0 = "lon"

    ; Locate the time/lat/lon dimensions of SF_ts by name.
    dimNames = getvardims(SF_ts)
    lat_idx  = ind(dimNames .eq. "lat")
    lon_idx  = ind(dimNames .eq. "lon")
    time_idx = ind(dimNames .eq. "time")
    if (ismissing(lat_idx) .or. ismissing(lon_idx) .or. ismissing(time_idx)) then
        print("printPlumb_time: SF_ts must have named dimensions 'time', 'lat' and 'lon'")
        exit
    end if

    ; Deviation streamfunction: SF_ts with its zonal mean removed.
    PSIdev_ts = dim_rmvmean_n(SF_ts,lon_idx)

    ; Lat/lon in radians, and cos(lat) broadcast to the shape of SF_ts.
    latr   = pi/180.0 * lat
    lonr   = pi/180.0 * lon
    coslat = conform_dims(dimsizes(SF_ts),cos(latr),lat_idx)

    ; Derivatives of PSIdev. Longitude is treated as cyclic (True);
    ; latitude is not (False).
    dPSIdevdlon     = center_finite_diff_n(PSIdev_ts,lonr,True,0,lon_idx)
    ddPSIdevdlonlon = center_finite_diff_n(dPSIdevdlon,lonr,True,0,lon_idx)
    dPSIdevdlat     = center_finite_diff_n(PSIdev_ts,latr,False,0,lat_idx)
    ddPSIdevdlonlat = center_finite_diff_n(dPSIdevdlon,latr,False,0,lat_idx)

    ; Bracketed terms of each component, with cos(lat) mostly taken inside.
    xterm1 = dPSIdevdlon * dPSIdevdlon
    xterm2 = PSIdev_ts * ddPSIdevdlonlon
    yterm1 = dPSIdevdlat * dPSIdevdlon
    yterm2 = PSIdev_ts * ddPSIdevdlonlat

    ; Coefficient (p/1000)/(2 a^2): pressure normalized by 1000 mb.
    coeff = conform_dims(dimsizes(SF_ts),lev_p/(1000.0 *2.0 * a * a),-1)

    ; Flux components at each time. NOTE: no masking of weak or easterly winds.
    Fx_ts = (coeff/coslat) * (xterm1 - xterm2)
    Fy_ts = coeff * (yterm1 - yterm2)

    ; Take the mean over the time dimension.
    Fx = dim_avg_n_Wrap(Fx_ts,time_idx)
    Fy = dim_avg_n_Wrap(Fy_ts,time_idx)
    SF = dim_avg_n_Wrap(SF_ts,time_idx)
    PSIdev = dim_avg_n_Wrap(PSIdev_ts,time_idx)

    ; Output metadata for the time means (assumes lat precedes lon).
    print(dimsizes(Fx))
    Fx!0 = "lat"
    Fx&lat = lat
    Fx!1 = "lon"
    Fx&lon = lon

    copy_VarMeta(Fx,Fy)
    copy_VarMeta(Fx,PSIdev)
    copy_VarMeta(Fx,SF)

    ; The time-dependent derivatives come from arithmetic/center_finite_diff_n,
    ; so they have no named dims. Give them SF_ts's dimension names plus the
    ; lat/lon coordinates so they are written to filo with proper dimensions.
    do idim = 0, dimsizes(dimNames)-1
        dPSIdevdlon!idim = dimNames(idim)
    end do
    dPSIdevdlon&lat = lat
    dPSIdevdlon&lon = lon
    copy_VarMeta(dPSIdevdlon,dPSIdevdlat)
    copy_VarMeta(dPSIdevdlon,ddPSIdevdlonlon)
    copy_VarMeta(dPSIdevdlon,ddPSIdevdlonlat)

    Fx@units = "m^2/s^2"
    Fy@units = "m^2/s^2"
    PSIdev@units = "m^2/s"

    filo->Fx = Fx
    filo->Fy = Fy
    filo->PSIdev = PSIdev
    filo->SF = SF
    filo->dPSIdevdlon = dPSIdevdlon
    filo->dPSIdevdlat = dPSIdevdlat
    filo->ddPSIdevdlonlon = ddPSIdevdlonlon
    filo->ddPSIdevdlonlat = ddPSIdevdlonlat

    return(1)

end

